Hierarchical Models

Matthew Talluto

07.12.2022

Precipitation-mortality relationships in Tsuga

# Load the full tree survival dataset (data.table) and subset to hemlock
trees = fread("exercises/data/treedata.csv")
# data.table syntax: grep() on the species_name column returns the row
# indices of all records in the genus Tsuga
tsuga = trees[grep("Tsuga", species_name)]
# remove NAs
# tsuga = tsuga[complete.cases(tsuga), ]
head(tsuga)
##    n died year     species_name annual_mean_temp tot_annual_pp   prior_mu
## 1: 5    3 1989 Tsuga canadensis         3.849333     1003.0000 0.02222166
## 2: 6    0 1989 Tsuga canadensis         3.452000     1076.1333 0.02222166
## 3: 3    0 1997 Tsuga canadensis         3.620000     1099.4000 0.03245619
## 4: 4    4 1989 Tsuga canadensis         4.596000      989.4667 0.02222166
## 5: 3    0 1994 Tsuga canadensis         4.244667     1116.2000 0.02816339
## 6: 7    0 2002 Tsuga canadensis         4.730000     1137.2667 0.04108793

Precipitation-mortality relationships in Tsuga: H1

Precipitation-mortality relationships in Tsuga: H2

Precipitation-mortality relationships in Tsuga: H3

Pooled model code

# compile the model (below)
# stan_model() translates the Stan program to C++ and compiles it;
# slow the first time, but the compiled model is cached by rstan
tsuga_pooled = stan_model("stan/tsuga_pooled.stan")
// file: stan/tsuga_pooled.stan
// Completely pooled binomial GLM: a single intercept and slope are
// shared by all observations, so every plot/year informs the same
// two parameters.
data {
    // number of data points
    int <lower=0> n; 
    
    // number of trees in each plot
    int <lower=1> n_trees [n]; 

    // number died
    int <lower=0> died [n]; 
    // (standardised) total annual precipitation per plot
    vector [n] precip;
}
parameters {
    real a;  // intercept, logit scale
    real b;  // precipitation effect, logit scale
}
transformed parameters {
    // per-plot mortality probability
    vector <lower=0, upper=1> [n] p;
    p = inv_logit(a + b * precip);
}
model {
    // likelihood: died[i] of n_trees[i] die with probability p[i]
    died ~ binomial(n_trees, p);
    // weakly informative priors on the logit scale
    a ~ normal(0, 10);
    b ~ normal(0, 5);
}
generated quantities {
    // we use generated quantities to keep track of log likelihood and
    // deviance, useful for model selection
    // and also to perform posterior predictive simulations
    real deviance = 0;
    vector [n] loglik;
    int ppd_died [n];
    for (i in 1:n) {
        loglik[i] = binomial_lpmf(died[i] | n_trees[i], p[i]);
        deviance += loglik[i];
        // NOTE(review): simulates a hypothetical plot of 20 trees rather
        // than n_trees[i]; this matches the N = 20 used for the posterior
        // predictive simulations later in the document -- confirm intended
        ppd_died[i] = binomial_rng(20, p[i]);
    }
    deviance = -2 * deviance;
}
## Fit the model
## note that I have rescaled precipitation
## additionally, we add the year variable as a factor
# scale() centres and standardises; it returns a one-column matrix
precip = scale(tsuga$tot_annual_pp) 
standat = with(tsuga, list(
    n = length(died),
    n_trees = n,           # column 'n' of tsuga: trees per plot
    died = died,
    precip = precip[,1],   # drop the matrix attributes from scale()
    year = as.factor(year)))  # unused by the pooled model; converted to an integer index for the grouped models further down
fit_pooled = sampling(tsuga_pooled, chains=4, iter=3000, refresh=0, data = standat)

Pooled models trade accuracy for precision

Pooled models trade accuracy for precision

Unpooled model: code

# compile the model (below)
# one intercept per year, estimated independently (no pooling)
tsuga_unpooled = stan_model("stan/tsuga_unpooled.stan")
// file: stan/tsuga_unpooled.stan
// Unpooled model: a completely separate intercept for each year,
// estimated independently -- no information is shared among years.
data {
    // same as in the pooled model
    int <lower=0> n; 
    int <lower=1> n_trees [n]; 
    int <lower=0> died [n]; 
    vector [n] precip;
    
    // grouping variables
    // year_id is an integer starting at 1 (the earliest year)
    // ending at n_groups (the latest year)
    // we use this value as an index for any group-level effects
    int <lower=1> n_groups;
    int <lower=1, upper = n_groups> year_id [n];

}
parameters {
    // one intercept per group
    vector [n_groups] a;
    real b;  // a single slope is still shared by all years
}
transformed parameters {
    vector <lower=0, upper=1> [n] p;
    
    // a is different for each data point, depending on the group
    // so we need a loop to compute this
    for(i in 1:n) {
        int gid = year_id[i];
        p[i] = inv_logit(a[gid] + b * precip[i]);
    }
}
model {
    died ~ binomial(n_trees, p);
    // fixed priors; the statement for a is vectorised over all groups
    a ~ normal(0, 10);
    b ~ normal(0, 5);
}
generated quantities {
    // we use generated quantities to keep track of log likelihood and
    // deviance, useful for model selection
    // and also to perform posterior predictive simulations
    real deviance = 0;
    vector [n] loglik;
    int ppd_died [n];
    for (i in 1:n) {
        loglik[i] = binomial_lpmf(died[i] | n_trees[i], p[i]);
        deviance += loglik[i];
        // NOTE(review): hypothetical plot of 20 trees, not n_trees[i]
        ppd_died[i] = binomial_rng(20, p[i]);
    }
    deviance = -2 * deviance;
}
# factor variables must be converted to integers for stan
# as.integer() on a factor returns the underlying level codes (1, 2, ...)
standat$year_id = as.integer(standat$year)

# we also need to tell stan how many groups (i.e., years) there are
standat$n_groups = max(standat$year_id)
##      year year_int
## [1,] 1989        1
## [2,] 1994        2
## [3,] 1997        3
## [4,] 1998        4
## [5,] 2001        5
## [6,] 2002        6
# same data list; the extra elements needed by this model are now present
fit_unpooled = sampling(tsuga_unpooled, chains=4, iter=3000, refresh=0, 
                        data = standat)

Unpooled models use less data per parameter

Unpooled models use less data per parameter

Compromise: partial pooling

Partial Pooling: Code

# compile the model (below)
# partial pooling: group effects share a common (estimated) distribution
tsuga_ppool = stan_model("stan/tsuga_ppool.stan")
// file: stan/tsuga_ppool.stan
// Partial pooling: per-year intercepts and slopes are drawn from shared
// normal distributions whose parameters (hyperparameters) are estimated
// from the data, shrinking poorly-informed years toward the overall mean.
data {
    // same as in the pooled model
    int <lower=0> n; 
    int <lower=1> n_trees [n]; 
    int <lower=0> died [n]; 
    vector [n] precip;
    
    // grouping variables
    int <lower=1> n_groups;
    int <lower=1, upper = n_groups> year_id [n];

}
parameters {
    // one intercept and one slope per group
    vector [n_groups] a;
    vector [n_groups] b;
    
    // hyperparameters describe higher structure
    real a_mu;
    real <lower=0> a_sig;
    real b_mu;
    real <lower=0> b_sig;

}
transformed parameters {
    vector <lower=0, upper=1> [n] p;
    
    // per-observation probability using that observation's year parameters
    for(i in 1:n) {
        int gid = year_id[i];
        p[i] = inv_logit(a[gid] + b[gid] * precip[i]);
    }
}
model {
    died ~ binomial(n_trees, p);
    
    // The priors are now estimated from the data
    a ~ normal(a_mu, a_sig);
    b ~ normal(b_mu, b_sig);
    
    // hyperpriors describe what we know about higher (group-level) structure
    a_mu ~ normal(0, 20);
    b_mu ~ normal(0, 20);
    // half cauchy is common for hierarchical stdev
    // (the <lower=0> constraint on a_sig/b_sig truncates the cauchy at 0)
    a_sig ~ cauchy(0, 20);
    b_sig ~ cauchy(0, 20);
    
}
generated quantities {
    // we use generated quantities to keep track of log likelihood and
    // deviance, useful for model selection
    // and also to perform posterior predictive simulations
    real deviance = 0;
    vector [n] loglik;
    int ppd_died [n];
    for (i in 1:n) {
        loglik[i] = binomial_lpmf(died[i] | n_trees[i], p[i]);
        deviance += loglik[i];
        // NOTE(review): hypothetical plot of 20 trees, not n_trees[i]
        ppd_died[i] = binomial_rng(20, p[i]);
    }
    deviance = -2 * deviance;
}
# fit the partial pooling model; the data list from above already
# contains year_id and n_groups
fit_ppool = sampling(tsuga_ppool, chains=4, iter=3000, refresh=0, 
                        data = standat)

Partial pooling is a compromise

Partial pooling is a compromise

Pooling comparison

Pooling comparison

When do we need hierarchical models?

Designing hierarchical models in Stan

// Schematic of the basic hierarchical-model structure in Stan.
// NOTE(review): this is a slide sketch, not a compilable program --
// n is not declared, pr is used without declaration or a loop, and the
// likelihood is omitted; a_sig should also carry <lower=0> as in the
// complete models above.
data {
    // group-level objects
    int <lower=1> n_groups;
    int <lower=1, upper=n_groups> group_id [n];
}
parameters {
    vector [n_groups] a; 
    
    // hyperparameters
    real a_mu;
    real a_sig;
}
transformed parameters {
    pr[i] = inv_logit(a[group_id[i]]);
}
model {
    a ~ normal(a_mu, a_sig);  // hierarchical prior for a
}

Designing hierarchical models in Stan

// Hierarchical model with two crossed grouping factors: every observation
// receives intercept contributions from both its group1 and its group2.
data {
    int n; // number of data points
    int died [n];  // deaths per plot
    int N[n];      // trees per plot
    vector [n] precip;

    // group-level objects
    int <lower=1> n_group1;
    int <lower=1, upper=n_group1> group1_id [n];

    int <lower=1> n_group2;
    int <lower=1, upper=n_group2> group2_id [n];
}
parameters {
    vector [n_group1] a1; 
    vector [n_group2] a2; 
    real b;  // precipitation slope (used below; was missing from the original)
    
    // hyperparameters
    real a1_mu;
    real <lower=0> a1_sig;
    real a2_mu;
    real <lower=0> a2_sig;
}
transformed parameters {
    vector [n] pr;
    for(i in 1:n)
        pr[i] = inv_logit(a1[group1_id[i]] + a2[group2_id[i]] + b*precip[i]);
}
model {
    died ~ binomial(N, pr); // likelihood

    a1 ~ normal(a1_mu, a1_sig);  // hierarchical prior for a1
    a2 ~ normal(a2_mu, a2_sig);  // hierarchical prior for a2
    b ~ normal(0, 5);            // same slope prior as the models above

    // hyperpriors
    a1_mu ~ normal(0, 10);
    a2_mu ~ normal(0, 10);
    a1_sig ~ gamma(0.1, 0.1);
    a2_sig ~ gamma(0.1, 0.1);
}

Designing hierarchical models in Stan

// Nested hierarchical model: group1 units are nested within group2 units
// (e.g. plots within regions); group2_id maps each group1 unit to its
// parent, so it has length n_group1, not n.
data {
    int n; // number of data points
    int died [n];
    int N[n];
    vector [n] temperature;

    // group-level objects
    int <lower=1> n_group1;
    int <lower=1, upper=n_group1> group1_id [n];

    // one parent group per group1 unit (note the n_group1 dimension)
    int <lower=1> n_group2;
    int <lower=1, upper=n_group2> group2_id [n_group1];
}
parameters {
    vector [n_group1] a1; 
    vector [n_group2] a2; 
    real b;  // temperature slope (used below; was missing from the original)
    
    // hyperparameters
    real <lower=0> a1_sig;
    real a2_mu;
    real <lower=0> a2_sig;
}
transformed parameters {
    vector [n] pr;
    for(i in 1:n)
        // only the lowest-level intercept enters the linear predictor;
        // the covariate in this model is temperature (precip was a typo)
        pr[i] = inv_logit(a1[group1_id[i]] + b*temperature[i]);
}
model {
    died ~ binomial(N, pr); // likelihood

    // each group1 intercept is centred on its parent group2 intercept
    for(i in 1:n_group1)
        a1[i] ~ normal(a2[group2_id[i]], a1_sig);  // hierarchical prior for a1
    // hyperpriors
    a2 ~ normal(a2_mu, a2_sig);  // hierarchical prior for a2
    a1_sig ~ gamma(0.1, 0.1);
    b ~ normal(0, 5);
    
    // hyperhyperprior
    a2_mu ~ normal(0, 10);
}

Posterior predictive distributions

# Simulate posterior-predictive mortality counts for hypothetical new groups:
# draw a fresh intercept and slope (from the group-level distributions) for
# each precipitation value, then simulate deaths out of N trees.
sim1 = function(amu, asig, bmu, bsig, N, precip) {
    npts = length(precip)
    intercepts = rnorm(npts, amu, asig)
    slopes = rnorm(npts, bmu, bsig)
    pr_mort = plogis(intercepts + slopes * precip)
    rbinom(npts, N, pr_mort)
}

Posterior predictive distributions

# Prediction gradient over the observed (standardised) precipitation range
newx = seq(min(standat$precip), max(standat$precip), length.out=400)
# Posterior draws of the hyperparameters from the partial-pooling model
pars = data.frame(as.matrix(fit_ppool, pars=c("a_mu", "a_sig", "b_mu", "b_sig")))

# For our hypothetical, we need to decide how many trees we would see
# more trees means less sampling uncertainty
N = 20
# One simulation per posterior draw; result is a length(newx) x n_draws matrix
sims = mapply(sim1, amu = pars$a_mu, asig = pars$a_sig,
              bmu = pars$b_mu, bsig = pars$b_sig, 
              MoreArgs = list(N = N, precip = newx))  # use N, not a hard-coded 20
# median and 90% interval of simulated deaths at each precipitation value
sim_quantiles = apply(sims, 1, quantile, c(0.5, 0.05, 0.95))